""" 

:version: 03/05/2020

"""
from hexapole_old import HexVector, Assembly
from Verlet import verletFlyer, loadFinal, rewind

import numpy as np
import logging
import matplotlib.pyplot as plt
import matplotlib as mpl
import FigureSetup
import os
import os.path
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from matplotlib.ticker import FormatStrFormatter



# Font settings passed as keyword arguments to the matplotlib axis-label calls.
csfont = {'fontname': 'Times New Roman', 'fontsize': 12}

# Send log records that reach the root logger to the console.  The handler
# lets everything from DEBUG upwards through; raise its level to logging.INFO
# for a quieter output.
logger = logging.getLogger()
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
# Each line shows the emitting logger, the severity and the message (the
# previous format repeated the logger name and never showed the severity).
ch.setFormatter(logging.Formatter('%(name)s - %(levelname)s - %(message)s'))
logger.addHandler(ch)


def Guide(HA1e2ypos, HA3e4ypos, HA4zshift, heightBlade1, heightBlade2):
    """Assemble the four-element, double-shift guide.

    Four HexVector elements are created from the field map
    'Bvec_ri3lh7B1415.h5' and wrapped in an Assembly.  Elements 1 and 2 share
    the offset HA1e2ypos in the second position component, elements 3 and 4
    share HA3e4ypos.  Within each pair the second element is placed
    2*focallength further along the third position component; element 4 is
    additionally moved by HA4zshift.  ``focallength`` is read from module
    scope when the function is called.

    heightBlade1 and heightBlade2 are accepted but not used here.

    Returns the Assembly holding the four elements in order 1, 2, 3, 4.
    """
    field_map = 'Bvec_ri3lh7B1415.h5'
    pair_spacing = 2*focallength

    ### Guide specifications
    z_first = 241.15
    # Original note on the offset below:
    # HA1zpos + 2f(MaxVel) + 2*(halfwidth arrays) + (arbitrary spacing)
    z_third = z_first+24.0+5.0+5.0

    ### double shift design
    centres = (
        [0.0, HA1e2ypos, z_first],
        [0.0, HA1e2ypos, z_first+pair_spacing+0.0],
        [0.0, HA3e4ypos, z_third],
        [0.0, HA3e4ypos, z_third+pair_spacing+HA4zshift],
    )
    return Assembly([HexVector(field_map, position=centre) for centre in centres])

# ---- Parameters to vary -----------------------------------------------------
input_folder = r'D:\Zeeman\Guide simulations_new\Input'
output_folder = r'D:\Zeeman\Guide simulations_new\Output'

# States to simulate (not for streak plots); an earlier variant also listed
# states 1, 2 and 3.
states = [0]

focallength = 12.7      # alternatives tried: 12.7, 9.4, 6.7
posCoil12 = 226.6
detectionPos = 319.00   # ion trap position (using ion trap area detection); 350.00 also tried

# ---- Guide specifications ---------------------------------------------------
posBlade1 = 270.65
posBlade2 = 310.0
HA1e2ypos = 1.5
HA3e4ypos = 0.5
HA4zshift = 0.0
heightBlade1 = 1.1
heightBlade2 = -0.2

# ---- HA dimensions (to draw guide) ------------------------------------------
HAwidth = 7.0
HAradiusINT = 3.0
HAhalfwidth = HAwidth / 2
HAradiusEXT = HAradiusINT + 4.0
laserwidth = 0.8        # 1.0 also tried


# #########TestGuideSim.py##########################################################################################
# ###############################################################################################################
# ######### Generates ypos/TOF/Vel histograms and %particles/%loss for target/faster/slower particles, with noGuide/Guide 

# Empty accumulators that the simulation loop extends with np.concatenate.
# Three data sets are kept: "all" (every particle that did not collide),
# "notskimmed" (not collided and not skimmed) and "noGuide" (free flight).
# Positions and velocities hold three components per particle, times one value.
pos_all, pos_notskimmed, pos_noGuide = (
    np.empty((0, 3), dtype=np.float64) for _ in range(3))
vel_all, vel_notskimmed, vel_noGuide = (
    np.empty((0, 3), dtype=np.float64) for _ in range(3))
times_all, times_notskimmed, times_noGuide = (
    np.empty((0,), dtype=np.float64) for _ in range(3))

# Dummy Guide for propagating particles in free flight (no guide simulation)
def NOGUIDE(HA1e2ypos, HA3e4ypos, HA4zshift, heightBlade1, heightBlade2):
    """Build a placeholder guide for free-flight propagation.

    Same four-element layout as the real guide (field map
    'Bvec_ri3lh7B1415.h5', two pairs whose members are 2*focallength apart,
    the fourth element additionally moved by HA4zshift), but the first
    element is placed at 500 in the third position component so that the
    assembly should lie far beyond the decelerator/guide and not be in the
    way.  ``focallength`` is read from module scope at call time.

    heightBlade1 and heightBlade2 are accepted but not used here.

    Returns the Assembly holding the four elements.
    """
    ### Guide specifications
    start_z = 500  # should be far beyond the decelerator/guide so it is not in the way
    # Original note: HA1zpos + 2f(MaxVel) + 2*(halfwidth arrays) + (arbitrary spacing)
    second_pair_z = start_z+24.0+5.0+5.0

    ### double shift design
    elements = []
    for y_offset, z_centre in (
            (HA1e2ypos, start_z),
            (HA1e2ypos, start_z+2*focallength+0.0),
            (HA3e4ypos, second_pair_z),
            (HA3e4ypos, second_pair_z+2*focallength+HA4zshift)):
        elements.append(
            HexVector('Bvec_ri3lh7B1415.h5', position=[0.0, y_offset, z_centre]))
    return Assembly(elements)

def detect_at(plane, max_time, pos_final, vel, laser_width=1.0, laser_hight=1.0, nsteps=1200, dt=1):
    """Build a time-of-flight trace by stepping particles backwards in time.

    Starting from ``pos_final`` (the positions at time ``max_time``), step
    ``i`` moves every particle back by ``i*dt*vel``.  Each particle then
    contributes a Gaussian weight centred on ``plane`` in position component 2
    (sigma ``laser_width``) and on 0 in position component 1 (sigma
    ``laser_hight``); the summed weights form the signal at that step.

    :param plane: detection position along position component 2.
    :param max_time: time that corresponds to ``pos_final``.
    :param pos_final: array of particle positions, one row per particle.
    :param vel: array of particle velocities, same shape as ``pos_final``.
    :param laser_width: Gaussian sigma along position component 2.
    :param laser_hight: Gaussian sigma along position component 1 (the
        spelling is kept so existing keyword calls keep working).
    :param nsteps: number of time steps in the trace.
    :param dt: size of one time step.
    :return: ``(tof_time, tof_signal)``, two arrays of length ``nsteps``;
        ``tof_time`` starts at ``max_time`` and decreases by ``dt`` per step.
    """
    tof_signal = np.zeros(nsteps)
    tof_time = max_time - np.arange(nsteps)*dt
    for i in range(nsteps):
        ## Move particles
        pos_fly = pos_final - i*dt*vel
        ## Detect within a Gaussian ellipse
        detected = (np.exp(-((pos_fly[:, 2]-plane)**2)/(2.0 * laser_width**2))
                    * np.exp(-((pos_fly[:, 1])**2)/(2.0 * (laser_hight)**2)))
        tof_signal[i] = np.sum(detected)
    return tof_time, tof_signal


filename = 'GuideAnalysis_Undec_vel350_DetPos319Output1_1p5_0p5_1p1_m0p2'
datafile = os.path.join(output_folder, filename + '.npz')
if os.path.exists(datafile):
    # A cached result exists: reuse it instead of re-flying the particles.
    # The archive is closed as soon as every array has been read from it.
    with np.load(datafile) as alldata:
        pos_all = alldata['pos_all']
        vel_all = alldata['vel_all']
        times_all = alldata['times_all']
        pos_notskimmed = alldata['pos_notskimmed']
        vel_notskimmed = alldata['vel_notskimmed']
        times_notskimmed = alldata['times_notskimmed']
        pos_noGuide = alldata['pos_noGuide']
        vel_noGuide = alldata['vel_noGuide']
        times_noGuide = alldata['times_noGuide']

        tof_time_guide = alldata['tof_time_guide']
        tof_signal_guide = alldata['tof_signal_guide']
        tof_time_noguide = alldata['tof_time_noguide']
        tof_signal_noguide = alldata['tof_signal_noguide']
else:
    #### Generate the guide
    hh = Guide(HA1e2ypos, HA3e4ypos, HA4zshift, heightBlade1, heightBlade2)
    hh_NOGUIDE = NOGUIDE(HA1e2ypos, HA3e4ypos, HA4zshift, heightBlade1, heightBlade2)
    for s in states:
        ############# Flying with skimming with 2 razor blades
        ### Load some atoms.
        pos, vel, times = loadFinal(input_folder, states=[s])
        ### Move atoms back to the middle of coil 12
        pos, vel, times = rewind(posCoil12, pos, vel, times)
        print("Particles going to first blade")
        ### Fly the atoms through the assembly up to the position of blade 1.
        ### ``s`` is the state itself (it is what was loaded above), so it is
        ### passed directly rather than used as an index into ``states``.
        pos, vel, times = verletFlyer(pos, vel, times, state=s, hexapole=hh, totalZ=posBlade1, dt=0.5, totalTime=500)
        ### Collect indices of the particles not skimmed by blade 1
        ind_notskimmed1 = np.where(pos[:, 1] > heightBlade1)[0]
        print("Particles going to second blade")
        ### Fly the atoms through the assembly up to the position of blade 2
        pos, vel, times = verletFlyer(pos, vel, times, state=s, hexapole=hh, totalZ=posBlade2, dt=0.5, totalTime=500)
        ### Collect indices of the particles not skimmed by blade 2
        ind_notskimmed2 = np.where(pos[:, 1] < heightBlade2)[0]
        ### Fly the atoms through the rest of the assembly
        pos, vel, times = verletFlyer(pos, vel, times, state=s, hexapole=hh, totalZ=detectionPos, dt=0.5, totalTime=500)
        ### Pick out particles that have not collided.  An additional cut on
        ### the ion-trap area (radius 1.3 mm around (0, 0, detectionPos), or
        ### |y| < laserwidth) was used previously and is currently disabled.
        ind_notColl = hh.notCollided(pos)
        ### Get indices of the particles that went through both blades and did not collide
        ind_notskimmed = np.intersect1d(ind_notskimmed1, ind_notskimmed2)
        ind_notSkimColl = np.intersect1d(ind_notskimmed, ind_notColl)
        ### Record pos, vel, times of all particles that have not collided
        pos_all = np.concatenate((pos_all, pos[ind_notColl, :]))
        vel_all = np.concatenate((vel_all, vel[ind_notColl, :]))
        times_all = np.concatenate((times_all, times[ind_notColl]))
        ### Record pos, vel, times of the particles that have not collided nor have been skimmed off
        pos_notskimmed = np.concatenate((pos_notskimmed, pos[ind_notSkimColl, :]))
        vel_notskimmed = np.concatenate((vel_notskimmed, vel[ind_notSkimColl, :]))
        times_notskimmed = np.concatenate((times_notskimmed, times[ind_notSkimColl]))

        ############################ Free flight with no guide after decelerator
        ### Load some atoms.
        pos, vel, times = loadFinal(input_folder, states=[s])
        # Move particles back to the position of the 12th coil
        pos, vel, times = rewind(posCoil12, pos, vel, times)
        # Propagate particles to the detection position. The hexapole is the
        # dummy guide placed outside of the range of the ion trap (so it
        # should not be in the physical way of propagation).
        # Particles are propagated as -1 particles so should not be affected by B-fields.
        pos, vel, times = verletFlyer(pos, vel, times, state=-1, hexapole=hh_NOGUIDE, totalZ=detectionPos, dt=0.5, totalTime=500)

        pos_noGuide = np.concatenate((pos_noGuide, pos))
        vel_noGuide = np.concatenate((vel_noGuide, vel))
        times_noGuide = np.concatenate((times_noGuide, times))
        print(s)

    # Everything below uses the arrays accumulated over *all* states, so it
    # runs once after the loop.

    # Fly all particles forwards to the latest arrival time, as if they
    # continued flying through the detection plane (as in SimExpCf.py).
    # The per-particle time difference is broadcast over the three components.
    pos_final_guide = pos_notskimmed + (np.max(times_notskimmed) - times_notskimmed)[:, np.newaxis] * vel_notskimmed
    pos_final_noGuide = pos_noGuide + (np.max(times_noGuide) - times_noGuide)[:, np.newaxis] * vel_noGuide

    detect_params = {'laser_width': 0.8, 'laser_hight': 0.4, 'nsteps': 1000, 'dt': 5.0}

    tof_time_guide, tof_signal_guide = detect_at(detectionPos, np.max(times_notskimmed), pos_final_guide, vel_notskimmed, **detect_params)
    tof_time_noguide, tof_signal_noguide = detect_at(detectionPos, np.max(times_noGuide), pos_final_noGuide, vel_noGuide, **detect_params)

    # Cache everything so that the next run can skip the simulation.
    np.savez(datafile,
             pos_all=pos_all, vel_all=vel_all, times_all=times_all,
             pos_notskimmed=pos_notskimmed, vel_notskimmed=vel_notskimmed,
             times_notskimmed=times_notskimmed,
             pos_noGuide=pos_noGuide, vel_noGuide=vel_noGuide, times_noGuide=times_noGuide,
             tof_time_guide=tof_time_guide, tof_signal_guide=tof_signal_guide,
             tof_time_noguide=tof_time_noguide, tof_signal_noguide=tof_signal_noguide)


# Normalise to the peak of the free-flight (no guide) trace.  The builtin
# float is used because the np.float alias no longer exists in current NumPy.
tof_norm = float(np.max(tof_signal_noguide))
# Offset added to the simulated time axis before plotting.
sim_delay = 40

fig, axs = FigureSetup.new_figure(nrows=1, ncols=1)
plt.xlim([350, 1250])
# plt.ylim([0, 0.015])
plt.ylim([-1.05, 1.05])
# plt.yticks([0, 0.005, 0.01])
# The simulation is drawn downwards (negative values), so the tick labels
# are mirrored to show magnitudes on both halves of the axis.
plt.yticks([-1.00, -0.5, 0.00, 0.5, 1.0], ["1.0", "0.5", "0.0", "0.5", "1.0"])
plt.ylabel('Intensity (norm.)', **csfont)
# Raw string: keeps the backslash of the mathtext \mu without relying on an
# invalid escape sequence.
plt.xlabel(r'TOF ($\mu$s)', **csfont)


axs.fill_between(tof_time_noguide+sim_delay, 0, -tof_signal_noguide/tof_norm, facecolor='#db9d29', alpha=0.6, label="No Guide")
# axs.fill_between(tof_time_guide+sim_delay, 0, -tof_signal_guide/tof_norm, facecolor= '#158c22', alpha=0.6, label = "Guide")

axs.annotate("Experiment", (1000, 0.20), textcoords='data', size=11)
axs.annotate("Simulation", (1000, -0.25), textcoords='data', size=11)
axs.legend(loc='upper right', prop={'size': 10})

plt.show()